#!/bin/bash
#
# Installs Draw2Img (https://github.com/GradientSurfer/Draw2Img) into
# ./draw2Img with a Python venv and nightly ROCm 6.0 PyTorch wheels, then
# writes a launcher script inside the checkout.
#
# Usage: bash <this script>
#   The checkout is created relative to the current working directory.

# Stop at the first failing command (or failing pipeline stage) instead of
# continuing with a broken checkout/venv and reporting success at the end.
# NOTE(review): nounset (-u) is deliberately not enabled because a venv
# activate script is sourced later and is not verified here to be -u safe.
set -eo pipefail

readonly RUN_SCRIPT="run_draw2Img.sh" # launcher written inside the checkout
readonly software_dir="draw2Img"      # clone target, relative to $PWD
readonly venv_name="amdgpu"           # venv directory inside the checkout

# Clone the sources only when the checkout directory is not already present.
if [[ ! -d "$software_dir" ]]; then
  git clone https://github.com/GradientSurfer/Draw2Img "$software_dir" || exit
fi

# Everything below (venv, package install, launcher) lives inside the
# checkout; never continue in the wrong directory if the cd fails.
cd -- "$software_dir" || exit

# Create the venv unless one with an activate script already exists. Testing
# for the activate script rather than just the directory means a venv left
# half-created by an interrupted earlier run gets (re)built.
if [[ ! -f "$venv_name/bin/activate" ]]; then
  python3 -m venv "$venv_name" || exit
fi

# shellcheck source=/dev/null
source "$venv_name/bin/activate" || exit

# Pre-release (nightly) PyTorch wheels built for ROCm 6.0, then the Draw2Img
# package from the current checkout. A failed install aborts the script.
pip install --pre torch torchvision torchaudio \
  --index-url https://download.pytorch.org/whl/nightly/rocm6.0 || exit
pip install . || exit

echo "建立启动脚本"
# Write the launcher only if it does not exist yet; an existing file is kept.
# The here-doc delimiter is unquoted, so $venv_name is substituted now, while
# the \$-escaped expansions are left for the launcher to evaluate at run time.
# NOTE(review): HSA_OVERRIDE_GFX_VERSION=11.0.0 presumably makes ROCm treat the
# GPU as a gfx1100 target; confirm it matches the card being used.
if [[ ! -f "$RUN_SCRIPT" ]]; then
  cat > "$RUN_SCRIPT" <<EOF
#!/bin/bash
# Run from the directory containing this launcher so the relative paths below
# resolve no matter where it is invoked from.
cd "\$(dirname "\${BASH_SOURCE[0]}")" || exit
source "$venv_name/bin/activate" || exit
export HSA_OVERRIDE_GFX_VERSION=11.0.0
python draw2img/main.py
EOF
  chmod +x "$RUN_SCRIPT"
fi

echo "安装完成"